import os
import tensorflow as tf
import json

def ini_net(initializer_method):
    """
    Build a two-layer tanh MLP (64 units per layer) using the requested kernel initializer.

    Args:
        initializer_method: Name of the kernel initializer. One of 'Zeros', 'Ones',
            'GlorotNormal', 'GlorotUniform', 'HeNormal', 'HeUniform', 'RandomNormal',
            'RandomUniform', or the string 'None' (passes ``kernel_initializer=None``
            to each Dense layer).

    Returns:
        A ``tf.keras.Sequential`` model with two ``Dense(64, activation='tanh')`` layers.

    Raises:
        ValueError: If ``initializer_method`` is not one of the supported names.
    """
    # Factories rather than instances: each Dense layer gets its own initializer object.
    # Reusing one unseeded initializer instance across layers is discouraged by Keras
    # (same-shaped kernels could otherwise receive identical values).
    factories = {
        'Zeros': lambda: tf.keras.initializers.Zeros(),
        'Ones': lambda: tf.keras.initializers.Ones(),
        'GlorotNormal': lambda: tf.keras.initializers.GlorotNormal(seed=None),
        'GlorotUniform': lambda: tf.keras.initializers.GlorotUniform(seed=None),
        'HeNormal': lambda: tf.keras.initializers.HeNormal(seed=None),
        'HeUniform': lambda: tf.keras.initializers.HeUniform(seed=None),
        'RandomNormal': lambda: tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.05, seed=None),
        'RandomUniform': lambda: tf.keras.initializers.RandomUniform(minval=-0.05, maxval=0.05, seed=None),
        'None': lambda: None,
    }
    try:
        make_initializer = factories[initializer_method]
    except KeyError:
        # Previously an unknown name fell through every branch and surfaced as an
        # UnboundLocalError on `initializer`; report the actual problem instead.
        raise ValueError(
            f"Unknown initializer_method {initializer_method!r}; "
            f"expected one of {sorted(factories)}"
        ) from None

    model = tf.keras.Sequential()
    for _ in range(2):
        model.add(tf.keras.layers.Dense(64, activation='tanh', kernel_initializer=make_initializer()))

    return model


def save_missing_arguments(agent, dir):
    """
    Persist the PPO arguments that tensorforce cannot save when the agent uses a Keras network.

    When tensorforce saves a model with a Keras network, some arguments are not serialized,
    so they are written here manually as JSON files. The configuration for each argument
    is copied from tensorforce's PPO file on GitHub.

    Args:
        agent: Object exposing a ``spec`` mapping. Keys read: 'batch_size',
            'update_frequency', 'learning_rate', 'multi_step', 'subsampling_fraction',
            'likelihood_ratio_clipping', 'baseline', 'predict_terminal_values',
            'discount', 'return_processing', 'baseline_optimizer', 'memory', and
            'advantage_processing' (only when a baseline is set).
        dir: Existing directory to write into. Writes update.json, optimizer.json,
            objective.json, reward_estimation.json and memory.json.

    Raises:
        ValueError: If the baseline-related spec entries are inconsistent: no baseline
            but ``predict_terminal_values`` set or a ``baseline_optimizer`` given, or a
            baseline without a ``baseline_optimizer``.
    """
    spec = agent.spec

    update = dict(unit='episodes', batch_size=spec['batch_size'], frequency=spec['update_frequency'])
    optimizer = dict(
        optimizer='adam', learning_rate=spec['learning_rate'], multi_step=spec['multi_step'],
        subsampling_fraction=spec['subsampling_fraction']
    )
    objective = dict(
        type='policy_gradient', importance_sampling=True,
        clipping_value=spec['likelihood_ratio_clipping']
    )

    # Explicit raises instead of `assert`, which is stripped when Python runs with -O.
    if spec['baseline'] is None:
        if spec['predict_terminal_values']:
            raise ValueError("predict_terminal_values requires a baseline")
        reward_estimation = dict(
            horizon='episode', discount=spec['discount'], predict_horizon_values=False,
            estimate_advantage=False, reward_processing=None,
            return_processing=spec['return_processing']
        )
        if spec['baseline_optimizer'] is not None:
            raise ValueError("baseline_optimizer must be None when no baseline is set")
    else:
        reward_estimation = dict(
            horizon='episode', discount=spec['discount'], predict_horizon_values='early',
            estimate_advantage=True, predict_action_values=False,
            reward_processing=None, return_processing=spec['return_processing'],
            advantage_processing=spec['advantage_processing'],
            predict_terminal_values=spec['predict_terminal_values']
        )
        if spec['baseline_optimizer'] is None:
            raise ValueError("baseline_optimizer is required when a baseline is set")
    # NOTE: the baseline network / baseline objective are not persisted here;
    # load_missing_arguments only reads the five files written below.

    if spec['memory'] == 'minimum':
        memory = dict(type='recent')
    else:
        memory = dict(type='recent', capacity=spec['memory'])

    # Serialize data into files; `with` guarantees each handle is flushed and closed.
    outputs = {
        'update.json': update,
        'optimizer.json': optimizer,
        'objective.json': objective,
        'reward_estimation.json': reward_estimation,
        'memory.json': memory,
    }
    for filename, payload in outputs.items():
        with open(os.path.join(dir, filename), 'w', encoding='utf-8') as handle:
            json.dump(payload, handle)


def load_missing_arguments(dir):
    """
    Read back the JSON argument files produced by ``save_missing_arguments``.

    Args:
        dir: Directory containing update.json, optimizer.json, objective.json,
            reward_estimation.json and memory.json.

    Returns:
        Tuple ``(update, optimizer, objective, reward_estimation, memory)`` holding the
        decoded JSON content of each file, in that order.

    Raises:
        FileNotFoundError: If any of the five files is missing.
        json.JSONDecodeError: If a file does not contain valid JSON.
    """
    def _read(filename):
        # `with` closes each handle immediately instead of leaving it to the garbage collector.
        with open(os.path.join(dir, filename), encoding='utf-8') as handle:
            return json.load(handle)

    return tuple(
        _read(filename)
        for filename in (
            'update.json',
            'optimizer.json',
            'objective.json',
            'reward_estimation.json',
            'memory.json',
        )
    )
